{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "How-to Guide: Using a PIP package for fine-tuning a BERT model",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5T_-iFRIqliG",
        "colab_type": "text"
      },
      "source": [
        "## How-to Guide: Using a PIP package for fine-tuning a BERT model\n",
        "\n",
        "Author: [Chen Chen](https://github.com/chenGitHuber)\n",
        "\n",
        "In this example, we will work through fine-tuning a BERT model using the TensorFlow Model Garden PIP package (`tf-models-nightly`)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mY1vX5VAq4SS",
        "colab_type": "text"
      },
      "source": [
        "## License\n",
        "\n",
        "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XV3k63bt0ihl",
        "colab_type": "text"
      },
      "source": [
        "## Learning objectives\n",
        "\n",
        "In this Colab notebook, you will learn how to fine-tune a BERT model using the TensorFlow Model Garden PIP package."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MHA-RWherfG4",
        "colab_type": "text"
      },
      "source": [
        "## Enable GPU acceleration\n",
        "\n",
        "Enable a GPU runtime for better performance:\n",
        "\n",
        "*   Navigate to Edit → Notebook settings\n",
        "*   Select GPU from the \"Hardware accelerator\" drop-down list"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l2B4N5Djrs2l",
        "colab_type": "text"
      },
      "source": [
        "## Install the Model Garden PIP package\n",
        "\n",
        "Install the Model Garden PIP package (`tf-models-nightly`). pip also pulls in the packages it depends on, including a nightly build of TensorFlow (`tf-nightly`)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VMYZDly6rx97",
        "colab_type": "code",
        "outputId": "146956ab-4568-4de6-c78e-cf25f115a5a8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "!pip install tf-models-nightly"
      ],
      "execution_count": 0,
      "outputs": [
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w4oyRCTji-aa",
        "colab_type": "text"
      },
      "source": [
        "## BERT Fine-tuning\n",
        "\n",
        "The following code imports the modules needed to fine-tune a BERT model on a classification task."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rV8MIX7g078-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%tensorflow_version 2.x\n",
        "import json\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# TensorFlow Model Garden modules used by the fine-tuning code.\n",
        "from official.nlp import optimization\n",
        "from official.nlp.bert import bert_models\n",
        "from official.nlp.bert import configs as bert_configs\n",
        "from official.nlp.bert import run_classifier\n",
        "from official.utils.misc import distribution_utils"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DIuS8nYD08n3",
        "colab_type": "text"
      },
      "source": [
        "This section of code performs the following tasks:\n",
        "* Load the data for fine-tuning\n",
        "* Fine-tune a BERT model\n",
        "* Save the fine-tuned model in the TensorFlow SavedModel format\n",
        "\n",
        "See [create_finetuning_data.py](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) for how the training and evaluation data are created."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PAby1RTCi_1e",
        "colab_type": "code",
        "outputId": "e663e830-cc9b-4b5d-99db-4504fd66d5f3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 258
        }
      },
      "source": [
        "# MRPC fine-tuning data and its metadata.\n",
        "train_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_train.tf_record\"\n",
        "eval_data_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_eval.tf_record\"\n",
        "input_meta_path = \"gs://cloud-tpu-checkpoints/bert/classification/mrpc_meta_data\"\n",
        "\n",
        "# Pre-trained BERT-Base (uncased) configuration and checkpoint.\n",
        "bert_config_file = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json\"\n",
        "ckpt_path = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt\"\n",
        "\n",
        "# The metadata file is JSON; it records the sequence length, label count and dataset sizes.\n",
        "with tf.io.gfile.GFile(input_meta_path, 'rb') as reader:\n",
        "  input_meta_data = json.loads(reader.read().decode('utf-8'))\n",
        "\n",
        "max_seq_length = input_meta_data['max_seq_length']\n",
        "num_classes = input_meta_data['num_labels']\n",
        "batch_size = 32\n",
        "eval_batch_size = 32\n",
        "train_input_fn = run_classifier.get_dataset_fn(\n",
        "    train_data_path, max_seq_length, batch_size, is_training=True)\n",
        "eval_input_fn = run_classifier.get_dataset_fn(\n",
        "    eval_data_path, max_seq_length, eval_batch_size, is_training=False)\n",
        "\n",
        "# Train on a single GPU.\n",
        "strategy = distribution_utils.get_distribution_strategy(\n",
        "    distribution_strategy='one_device', num_gpus=1)\n",
        "\n",
        "with strategy.scope():\n",
        "  training_dataset = train_input_fn()\n",
        "  evaluation_dataset = eval_input_fn()\n",
        "  bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n",
        "  classifier_model, encoder = bert_models.classifier_model(\n",
        "      bert_config, num_classes, max_seq_length)\n",
        "\n",
        "  # Initialize the encoder from the pre-trained checkpoint.\n",
        "  checkpoint = tf.train.Checkpoint(model=encoder)\n",
        "  checkpoint.restore(ckpt_path).assert_consumed()\n",
        "\n",
        "  epochs = 3\n",
        "  train_data_size = input_meta_data['train_data_size']\n",
        "  eval_data_size = input_meta_data['eval_data_size']\n",
        "  steps_per_epoch = train_data_size // batch_size\n",
        "  # Warm up the learning rate over about the first 10% of the training steps.\n",
        "  warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
        "  optimizer = optimization.create_optimizer(\n",
        "      2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)\n",
        "\n",
        "  def metric_fn():\n",
        "    return tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "        'test_accuracy', dtype=tf.float32)\n",
        "\n",
        "  # Use the label count from the metadata rather than a hard-coded value.\n",
        "  classifier_model.compile(optimizer=optimizer,\n",
        "                           loss=run_classifier.get_loss_fn(num_classes=num_classes),\n",
        "                           metrics=[metric_fn()])\n",
        "  classifier_model.fit(\n",
        "      x=training_dataset,\n",
        "      validation_data=evaluation_dataset,\n",
        "      steps_per_epoch=steps_per_epoch,\n",
        "      epochs=epochs,\n",
        "      validation_steps=eval_data_size // eval_batch_size)\n",
        "\n",
        "  # Export the fine-tuned classifier in the SavedModel format.\n",
        "  classifier_model.save('/tmp/saved_model', include_optimizer=False, save_format='tf')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_mask.\n",
            "Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
            "The tensor that caused the issue was: input_mask:0\n",
            "WARNING:tensorflow:BertClassifier inputs must come from `tf.keras.Input` (thus holding past layer metadata), they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to \"bert_classifier\" was not an Input tensor, it was generated by layer input_type_ids.\n",
            "Note that input tensors are instantiated via `tensor = tf.keras.Input(shape)`.\n",
            "The tensor that caused the issue was: input_type_ids:0\n",
            "Epoch 1/3\n",
            "114/114 [==============================] - 96s 840ms/step - loss: 0.5932 - test_accuracy: 0.6960 - val_loss: 0.5083 - val_test_accuracy: 0.7604\n",
            "Epoch 2/3\n",
            "114/114 [==============================] - 100s 878ms/step - loss: 0.4225 - test_accuracy: 0.8183 - val_loss: 0.4020 - val_test_accuracy: 0.8438\n",
            "Epoch 3/3\n",
            "114/114 [==============================] - 100s 880ms/step - loss: 0.2482 - test_accuracy: 0.9134 - val_loss: 0.4065 - val_test_accuracy: 0.8151\n",
            "INFO:tensorflow:Assets written to: /tmp/saved_model/assets\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}